# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.ops_wrapper."""
import pytest
import mindspore.scipy as msp
from mindspore import context, Tensor
from tests.st.scipy_st.utils import match_array

# Translate the integer alignment index stored in each test-case key into the
# `align` string understood by matrix_diag_part (index -> name, in this order).
aligndict = dict(enumerate(("LEFT_RIGHT", "LEFT_LEFT", "RIGHT_LEFT", "RIGHT_RIGHT")))
# Padding value handed to matrix_diag_part; the expected results below were
# produced with padding_value=-1 as well.
PAD_VALUE = -1


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('array_dict', [([[[5]]], {}),
                                        ([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]],
                                         {(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]],
                                          (-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]],
                                          (-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
                                          (-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
                                          (-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]],
                                          (-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]],
                                          (-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]],
                                          (0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]],
                                          (0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]],
                                          (1, 2, 3): [[[-1, 1], [1, 4]]]}),
                                        ([[[6, 1]]], {}),
                                        ([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]],
                                         {(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]],
                                          (-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]],
                                          (-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
                                          (-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
                                          (-2, 3, 1): [
                                              [[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]],
                                          (-2, 4, 2): [
                                              [[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1],
                                               [6, -1, -1]]], (-1, -1, 2): [[8, 3]],
                                          (-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]],
                                          (-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]],
                                          (-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
                                          (-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
                                          (-1, 4, 1): [
                                              [[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]],
                                          (0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]],
                                          (0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]],
                                          (0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
                                          (0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
                                          (1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]],
                                          (1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]],
                                          (1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}),
                                        ([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]],
                                                        (0, 0, 0): [[5]]}),
                                        ([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]],
                                         {(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]],
                                          (-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]],
                                          (-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
                                          (-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]],
                                          (-4, 1, 1): [
                                              [[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
                                          (-4, 2, 2): [
                                              [[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1],
                                               [1, -1, -1]]], (-3, -3, 2): [[1, 0]],
                                          (-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]],
                                          (-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
                                          (-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
                                          (-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
                                          (-3, 2, 1): [
                                              [[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]],
                                          (-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]],
                                          (-2, 0, 2): [[[2, 4, 2], [6, 5, 0], [0, 6, 7]]],
                                          (-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
                                          (-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
                                          (-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]],
                                          (-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
                                          (-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
                                          (0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]],
                                          (0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]],
                                          (1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}),
                                        ([[[6]], [[4]]], {}),
                                        ([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]],
                                         {(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]],
                                          (-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]],
                                                       [[6, 2, 5], [8, 5, -1], [4, -1, -1]]],
                                          (-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
                                                       [[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
                                          (-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
                                                       [[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
                                          (-1, -1, 2): [[3, 6], [8, 5]],
                                          (-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]],
                                          (-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]],
                                                       [[-1, 7, 1], [6, 2, 5], [8, 5, -1]]],
                                          (-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]],
                                                       [[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]],
                                          (0, 0, 0): [[2, 4, 3], [6, 2, 5]],
                                          (0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]],
                                          (0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]],
                                                      [[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]],
                                          (1, 1, 2): [[4, 2], [7, 1]],
                                          (1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}),
                                        ([[[4, 0]], [[7, 4]]], {}),
                                        ([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]],
                                          [[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]],
                                         {(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]],
                                          (-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]],
                                                       [[7, 6, 1], [4, 3, -1], [5, -1, -1]]],
                                          (-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
                                                       [[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
                                          (-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
                                                       [[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
                                          (-2, 3, 1): [
                                              [[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]],
                                              [[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]],
                                          (-2, 4, 2): [
                                              [[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1],
                                               [5, -1, -1]],
                                              [[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1],
                                               [5, -1, -1]]],
                                          (-1, -1, 2): [[7, 4], [4, 3]],
                                          (-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]],
                                          (-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]],
                                                       [[4, 5, 1], [7, 6, 1], [4, 3, -1]]],
                                          (-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
                                                       [[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
                                          (-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
                                                       [[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
                                          (-1, 4, 1): [
                                              [[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]],
                                              [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]],
                                          (0, 0, 0): [[3, 8, 0], [7, 6, 1]],
                                          (0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]],
                                          (0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]],
                                                      [[8, 7, 0], [4, 5, 1], [7, 6, 1]]],
                                          (0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
                                                      [[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
                                          (0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
                                                      [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
                                          (1, 1, 2): [[5, 1, 3], [4, 5, 1]],
                                          (1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]],
                                          (1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]],
                                                      [[7, 1, -1], [8, 7, 0], [4, 5, 1]]],
                                          (1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]],
                                                      [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}),
                                        ([[[4], [7]], [[3], [5]]],
                                         {(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]],
                                          (0, 0, 0): [[4], [3]]}),
                                        ([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]],
                                          [[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]],
                                         {(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]],
                                          (-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]],
                                                        [[0, 7, 4], [0, 7, -1], [0, -1, -1]]],
                                          (-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
                                                        [[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
                                          (-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]],
                                                       [[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]],
                                          (-4, 1, 1): [
                                              [[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
                                              [[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
                                          (-4, 2, 2): [
                                              [[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1],
                                               [3, -1, -1]],
                                              [[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1],
                                               [0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]],
                                          (-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]],
                                          (-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]],
                                                        [[4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
                                          (-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
                                                       [[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
                                          (-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
                                                       [[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
                                          (-3, 2, 1): [
                                              [[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]],
                                              [[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]],
                                          (-2, -2, 0): [[6, 8, 0], [0, 7, 4]],
                                          (-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]],
                                          (-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]],
                                                       [[2, 4, 2], [4, 4, 0], [0, 7, 4]]],
                                          (-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
                                                       [[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
                                          (-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
                                                       [[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
                                          (-1, -1, 2): [[0, 5, 5], [4, 4, 0]],
                                          (-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]],
                                          (-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]],
                                                       [[-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
                                          (-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]],
                                                       [[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
                                          (0, 0, 0): [[0, 0, 5], [2, 4, 2]],
                                          (0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]],
                                          (0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]],
                                                      [[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]],
                                          (1, 1, 2): [[2, 5], [8, 1]],
                                          (1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})])
def test_matrix_diag_part_net(array_dict):
    """
    Feature: mindspore.scipy.ops_wrapper.matrix_diag_part
    Description: extract single diagonals (k0 == k1) and diagonal bands (k0 < k1) from
        batched integer matrices for every `align` mode, padding with PAD_VALUE, and
        compare against reference results produced with TensorFlow.
    Expectation: every result matches the TensorFlow reference exactly.

    Each parametrized case is ``(a, kadict)``: ``a`` is a nested list of shape
    (batch, m, n) and ``kadict`` maps ``(k0, k1, align_index)`` to the expected output,
    where ``align_index`` is a key of ``aligndict``. The cases were generated with a
    script along these lines:

        import numpy as np
        from tensorflow.python.ops import array_ops

        aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
        cases = []
        for i in [1, 2]:
            for m, n in [(1, 1), (3, 3), (1, 2), (3, 5), (2, 1), (5, 3)]:
                A = np.random.randint(20, size=(i, m, n))
                kadict = {}
                for k0 in range(-m + 1, m - 1):
                    for k1 in range(k0, n):
                        align_ = (abs(k0) + abs(k1)) % 4
                        B = array_ops.matrix_diag_part(A, k=(k0, k1), align=aligndict[align_],
                                                       padding_value=-1)
                        kadict[(k0, k1, align_)] = B.numpy()
                cases.append((A, kadict))
        with open('dict.tst', 'w') as f:
            print(cases, file=f)

    Note: because ``k0`` only ranges over ``range(-m + 1, m - 1)``, inputs with
    ``m == 1`` produce an empty ``kadict`` and therefore assert nothing.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    a, kadict = array_dict
    for (k0, k1, align_), expected in kadict.items():
        # A single diagonal is requested with a scalar k, a band with the (k0, k1) pair.
        k = k0 if k0 == k1 else (k0, k1)
        r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k, PAD_VALUE, align=aligndict[align_])
        # Compare every case; previously this check only ran for the band (k0 != k1) branch,
        # so single-diagonal results were computed but never verified.
        match_array(expected, r_b.asnumpy())
